# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

#' Expect `x`, once converted with `as.vector()`, to equal `y`.
#'
#' @param x Object to convert (e.g. an Arrow Array or ChunkedArray).
#' @param y Expected R value.
#' @param ... Forwarded to `expect_equal()`; an unnamed third argument lands
#'   on its `ignore_attr` formal.
expect_as_vector <- function(x, y, ...) {
  expect_equal(object = as.vector(x), expected = y, ...)
}

#' Expect two objects to hold equal values once both are converted with
#' `as.data.frame()`.
#'
#' @param x,y Objects convertible to data.frame (e.g. Tables, tibbles).
#' @param ... Forwarded to `expect_equal()`.
expect_equal_data_frame <- function(x, y, ...) {
  expect_equal(
    object = as.data.frame(x),
    expected = as.data.frame(y),
    ...
  )
}

#' Expect `object` to inherit from `class` and to be an R6 object.
#'
#' The specific class is checked first, then the generic "R6" class.
#' Returns `object` invisibly, as `expect_s3_class()` does.
#'
#' @param object Object under test.
#' @param class Character vector of class name(s) passed to `expect_s3_class()`.
expect_r6_class <- function(object, class) {
  for (wanted in list(class, "R6")) {
    expect_s3_class(object, wanted)
  }
  invisible(object)
}

#' Mask `testthat::expect_equal()` in order to compare ArrowObjects using their
#' `Equals` methods from the C++ library.
#'
#' When both `object` and `expected` inherit from "ArrowObject", they are
#' compared with `all.equal()` and the result is asserted with `expect_true()`.
#' Otherwise every argument is forwarded to `testthat::expect_equal()`.
#'
#' @param object,expected Values to compare.
#' @param ignore_attr For two ArrowObjects, `TRUE` means
#'   `check.attributes = FALSE` in `all.equal()`; otherwise forwarded to
#'   `testthat::expect_equal()`.
#' @param ... Forwarded to `testthat::expect_equal()`. Not used when comparing
#'   two ArrowObjects.
#' @param info Extra information shown on failure.
#' @param label Failure-message label. When `NULL` and both values are
#'   ArrowObjects, the label is built as "<object expr> == <expected expr>".
expect_equal <- function(object, expected, ignore_attr = FALSE, ..., info = NULL, label = NULL) {
  if (inherits(object, "ArrowObject") && inherits(expected, "ArrowObject")) {
    # Honour a caller-supplied label instead of silently discarding it;
    # only derive one from the call expressions when none was given.
    if (is.null(label)) {
      mc <- match.call()
      label <- paste(rlang::as_label(mc[["object"]]), "==", rlang::as_label(mc[["expected"]]))
    }
    expect_true(
      all.equal(object, expected, check.attributes = !ignore_attr),
      info = info,
      label = label
    )
  } else {
    testthat::expect_equal(object, expected, ignore_attr = ignore_attr, ..., info = info, label = label)
  }
}

#' Expect two Arrow types to be equal.
#'
#' Either argument may be an Array, in which case its `$type` is compared
#' in place of the Array itself.
#'
#' @param object,expected DataTypes or Arrays.
#' @param ... Forwarded to `expect_equal()`.
expect_type_equal <- function(object, expected, ...) {
  type_or_self <- function(x) {
    if (is.Array(x)) x$type else x
  }
  object <- type_or_self(object)
  expected <- type_or_self(expected)
  expect_equal(object, expected, ...)
}

#' Expect `object` to raise a `match.arg()`-style error.
#'
#' The expected pattern is "'arg' " followed by anything, then the `dQuote()`d
#' `values` joined with ", ".
#'
#' @param object Expression expected to error (captured lazily by
#'   `expect_error()`).
#' @param values Character vector of the allowed choices.
expect_match_arg_error <- function(object, values = c()) {
  quoted_choices <- paste(dQuote(values), collapse = ", ")
  expect_error(object, paste0("'arg' .*", quoted_choices))
}

# Alias: deprecation checks are asserted as ordinary warnings. This binds
# whatever `expect_warning` is in scope when this file is sourced.
expect_deprecated <- expect_warning

#' Wrapper around `testthat::verify_output()` that skips the test when R's
#' platform string contains "conda".
#'
#' @param ... Forwarded to `testthat::verify_output()`.
verify_output <- function(...) {
  platform_is_conda <- isTRUE(grepl("conda", R.Version()$platform))
  if (platform_is_conda) {
    skip("On conda")
  }
  testthat::verify_output(...)
}

#' Ensure that dplyr methods on Arrow objects return the same as for data frames
#'
#' Runs a dplyr expression once on a tibble/data.frame and once on an Arrow
#' Table built from it, and checks that both results are equal.
#'
#' @param expr A dplyr pipeline which must have `.input` as its start
#' @param tbl A tibble or data.frame which will be substituted for `.input`
#' @param warning The warning expected while evaluating on the Arrow Table,
#'   passed to `expect_warning()`. Special values:
#'     * `NA` (the default) asserts that no warning is raised
#'     * `TRUE` is shorthand for the "> Pulling data into R" message
#' @param ... additional arguments, passed to `expect_equal()`
compare_dplyr_binding <- function(expr, tbl, warning = NA, ...) {
  # Capture `expr` unevaluated so it can be run against two inputs
  expr <- rlang::enquo(expr)

  # Reference result from regular dplyr on the data.frame
  reference_mask <- rlang::new_data_mask(rlang::env(.input = tbl))
  expected <- rlang::eval_tidy(expr, reference_mask)

  expected_warning <- if (isTRUE(warning)) "> Pulling data into R" else warning

  # Arrow result; `via_table` is assigned inside expect_warning()'s capture
  expect_warning(
    via_table <- rlang::eval_tidy(
      expr,
      rlang::new_data_mask(rlang::env(.input = arrow_table(tbl)))
    ),
    expected_warning
  )
  expect_equal(via_table, expected, ...)
}

#' Assert that Arrow dplyr methods error in the same way as methods on data.frame
#'
#' Compares the error message generated when running expressions on R objects
#' against the error raised by running the same expression on an Arrow
#' RecordBatch and an Arrow Table. The captured message is used as the
#' `regexp` for `expect_error()`.
#'
#' @param expr A dplyr pipeline which must have `.input` as its start
#' @param tbl A tibble or data.frame which will be substituted for `.input`
#' @param ... additional arguments, passed to `expect_error()`
compare_dplyr_error <- function(expr, tbl, ...) {
  # ensure we have supplied tbl
  force(tbl)

  expr <- rlang::enquo(expr)
  msg <- tryCatch(
    rlang::eval_tidy(expr, rlang::new_data_mask(rlang::env(.input = tbl))),
    error = function(e) {
      msg <- conditionMessage(e)

      # Wrapped errors ("Problem while computing ...") carry the original
      # condition in `e$parent`. Only unwrap when a parent actually exists:
      # conditionMessage(NULL) would itself error here and hide the real
      # failure.
      if (grepl("Problem while computing", msg[1]) && !is.null(e$parent)) {
        msg <- conditionMessage(e$parent)
      }

      # The error here is of the form:
      #
      # Problem with `filter()` input `..1`.
      # x object 'b_var' not found
      # ℹ Input `..1` is `chr == b_var`.
      #
      # but what we really care about is the `x` block
      # so (temporarily) let's pull those blocks out when we find them
      pattern <- i18ize_error_messages()

      if (grepl(pattern, msg)) {
        msg <- sub(paste0("^.*(", pattern, ").*$"), "\\1", msg)
      }
      msg
    }
  )
  # make sure msg is a character object (i.e. there has been an error)
  # If it did not error, we would get a data.frame or whatever
  # This expectation will tell us "dplyr on data.frame errored is not TRUE"
  expect_true(identical(typeof(msg), "character"), label = "dplyr on data.frame errored")

  expect_error(
    rlang::eval_tidy(
      expr,
      rlang::new_data_mask(rlang::env(.input = record_batch(tbl)))
    ),
    msg,
    ...
  )
  expect_error(
    rlang::eval_tidy(
      expr,
      rlang::new_data_mask(rlang::env(.input = arrow_table(tbl)))
    ),
    msg,
    ...
  )
}

#' Comparing the output of running expressions on R vectors against the same
#' expression run on Arrow Arrays and ChunkedArrays.
#'
#' The expected value comes from evaluating `expr` with `.input` bound to
#' `vec`. The same expression is then evaluated on an Array and on a
#' two-chunk ChunkedArray built from `vec`, and each result is compared with
#' `expect_as_vector()`. Requested skips are applied only after the
#' non-skipped comparisons have run.
#'
#' @param expr A vectorized R expression which must have `.input` as its start
#' @param vec A vector which will be substituted for `.input`
#' @param skip_array The skip message to show (if you should skip the Array test)
#' @param skip_chunked_array The skip message to show (if you should skip the ChunkedArray test)
#' @param ignore_attr Ignore differences in specified attributes? Forwarded by
#'   name to `expect_equal()` via `expect_as_vector()`.
#' @param ... additional arguments, passed to `expect_as_vector()`
compare_expression <- function(expr,
                               vec,
                               skip_array = NULL,
                               skip_chunked_array = NULL,
                               ignore_attr = FALSE,
                               ...) {
  expr <- rlang::enquo(expr)
  expected <- rlang::eval_tidy(expr, rlang::new_data_mask(rlang::env(.input = vec)))
  skip_msg <- NULL

  if (is.null(skip_array)) {
    via_array <- rlang::eval_tidy(
      expr,
      rlang::new_data_mask(rlang::env(.input = Array$create(vec)))
    )
    # Pass `ignore_attr` by name. Passed positionally, it only reaches the
    # right formal because of argument order; otherwise it would become an
    # unnamed extra argument to the underlying comparison.
    expect_as_vector(via_array, expected, ignore_attr = ignore_attr, ...)
  } else {
    skip_msg <- c(skip_msg, skip_array)
  }

  if (is.null(skip_chunked_array)) {
    # split input vector into two to exercise ChunkedArray with >1 chunk
    split_vector <- split_vector_as_list(vec)

    via_chunked <- rlang::eval_tidy(
      expr,
      rlang::new_data_mask(rlang::env(.input = ChunkedArray$create(split_vector[[1]], split_vector[[2]])))
    )
    expect_as_vector(via_chunked, expected, ignore_attr = ignore_attr, ...)
  } else {
    skip_msg <- c(skip_msg, skip_chunked_array)
  }

  if (!is.null(skip_msg)) {
    skip(paste(skip_msg, collapse = "\n"))
  }
}

#' Comparing the error message generated when running expressions on R objects
#' against the error message generated by running the same expression on Arrow
#' Arrays and ChunkedArrays.
#'
#' The error raised by evaluating `expr` on `vec` is turned into the `regexp`
#' that the Array and two-chunk ChunkedArray evaluations must match. Requested
#' skips are applied only after the non-skipped checks have run.
#'
#' @param expr An R expression which must have `.input` as its start
#' @param vec A vector which will be substituted for `.input`
#' @param skip_array The skip message to show (if you should skip the Array test)
#' @param skip_chunked_array The skip message to show (if you should skip the ChunkedArray test)
#' @param ... additional arguments, passed to `expect_error()`
compare_expression_error <- function(expr,
                                     vec,
                                     skip_array = NULL,
                                     skip_chunked_array = NULL,
                                     ...) {
  expr <- rlang::enquo(expr)

  # Convert the reference error into the expected-message string. When the
  # i18ize_error_messages() pattern matches, keep only the matched part.
  error_to_expected <- function(err) {
    text <- conditionMessage(err)
    pattern <- i18ize_error_messages()
    if (!grepl(pattern, text)) {
      return(text)
    }
    sub(paste0("^.*(", pattern, ").*$"), "\\1", text)
  }

  msg <- tryCatch(
    rlang::eval_tidy(expr, rlang::new_data_mask(rlang::env(.input = vec))),
    error = error_to_expected
  )

  # A non-character `msg` means evaluating on the plain vector did not error
  expect_true(identical(typeof(msg), "character"), label = "vector errored")

  skip_msg <- NULL

  if (!is.null(skip_array)) {
    skip_msg <- c(skip_msg, skip_array)
  } else {
    expect_error(
      rlang::eval_tidy(
        expr,
        rlang::new_data_mask(rlang::env(.input = Array$create(vec)))
      ),
      msg,
      ...
    )
  }

  if (!is.null(skip_chunked_array)) {
    skip_msg <- c(skip_msg, skip_chunked_array)
  } else {
    # two halves, so the ChunkedArray has more than one chunk
    split_vector <- split_vector_as_list(vec)
    expect_error(
      rlang::eval_tidy(
        expr,
        rlang::new_data_mask(rlang::env(.input = ChunkedArray$create(split_vector[[1]], split_vector[[2]])))
      ),
      msg,
      ...
    )
  }

  if (!is.null(skip_msg)) {
    skip(paste(skip_msg, collapse = "\n"))
  }
}

#' Split a vector into two consecutive halves.
#'
#' The first element holds the first `length(vec) %/% 2` items and the second
#' holds the rest, so for odd lengths the second half is one longer. Short
#' inputs give an empty first half (length 1) or two empty halves (length 0),
#' each keeping `vec`'s type.
#'
#' @param vec A vector.
#' @return A list of two vectors.
split_vector_as_list <- function(vec) {
  total <- length(vec)
  head_len <- total %/% 2L
  head_idx <- seq_len(head_len)
  tail_idx <- head_len + seq_len(total - head_len)
  list(vec[head_idx], vec[tail_idx])
}

#' Expect `expand_across()` on `tbl` (passed through `as_adq()`) to expand
#' `across_expr` into exactly the quosures built from `expected`.
#'
#' @param across_expr Expression(s) handed to `expand_across()`.
#' @param expected List of expressions passed to `new_quosures()`.
#' @param tbl Data passed to `as_adq()`.
expect_across_equal <- function(across_expr, expected, tbl) {
  expect_identical(
    object = expand_across(as_adq(tbl), across_expr),
    expected = new_quosures(expected)
  )
}

# Assert that evaluating `expr` with arrow_eval() against a mask built by
# arrow_mask(as_adq(.data)) raises an error.
#
# expr:  expression to evaluate inside the Arrow mask
# ...:   forwarded to expect_error() (e.g. a regexp or class)
# .data: data used to build the mask; defaults to `example_data`, which is
#        defined elsewhere
expect_arrow_eval_error <- function(expr, ..., .data = example_data) {
  mask <- arrow_mask(as_adq(.data))
  # NOTE(review): `{{ expr }}` assumes expect_error() captures its argument
  # with rlang (which performs `{{` injection), so arrow_eval() receives the
  # caller's expression rather than the bare symbol `expr`.
  expect_error(arrow_eval({{ expr }}, mask), ...)
}
